Skip to main content

Dot Product between Two Tensors Across the First Axis

This example shows how to use the `csdl.dot()` function to compute the dot product of two tensors along their first axis (`axis=0`).

import csdl
import numpy as np
from csdl import Model
from csdl_om import Simulator

class ExampleTensorTensorFirst(Model):
    """Tensor-tensor dot product contracted along the first axis.

    Declares two (3, 4, 5) input tensors, ``'ten1'`` and ``'ten2'``, and
    registers ``'TenTenDotFirst'`` as ``csdl.dot`` of the two with
    ``axis=0``.
    """

    def define(self):
        # Dimensions of both tensors.
        dim0, dim1, dim2 = 3, 4, 5
        shape = (dim0, dim1, dim2)
        count = np.prod(shape)

        # First tensor holds 0 .. count-1; the second holds
        # count .. 2*count-1. Both are reshaped to ``shape``.
        first_vals = np.arange(count).reshape(shape)
        second_vals = np.arange(count, 2 * count).reshape(shape)

        # Register both tensors as CSDL variables.
        first = self.declare_variable('ten1', val=first_vals)
        second = self.declare_variable('ten2', val=second_vals)

        # Dot product taken along the first axis.
        contracted = csdl.dot(first, second, axis=0)
        self.register_output('TenTenDotFirst', contracted)

# Build and run the simulation, then show each variable's shape and values.
sim = Simulator(ExampleTensorTensorFirst())
sim.run()

for var_name in ('ten1', 'ten2', 'TenTenDotFirst'):
    print(var_name, sim[var_name].shape)
    print(sim[var_name])
ten1 (3, 4, 5)
[[[ 0.  1.  2.  3.  4.]
  [ 5.  6.  7.  8.  9.]
  [10. 11. 12. 13. 14.]
  [15. 16. 17. 18. 19.]]

 [[20. 21. 22. 23. 24.]
  [25. 26. 27. 28. 29.]
  [30. 31. 32. 33. 34.]
  [35. 36. 37. 38. 39.]]

 [[40. 41. 42. 43. 44.]
  [45. 46. 47. 48. 49.]
  [50. 51. 52. 53. 54.]
  [55. 56. 57. 58. 59.]]]
ten2 (3, 4, 5)
[[[ 60.  61.  62.  63.  64.]
  [ 65.  66.  67.  68.  69.]
  [ 70.  71.  72.  73.  74.]
  [ 75.  76.  77.  78.  79.]]

 [[ 80.  81.  82.  83.  84.]
  [ 85.  86.  87.  88.  89.]
  [ 90.  91.  92.  93.  94.]
  [ 95.  96.  97.  98.  99.]]

 [[100. 101. 102. 103. 104.]
  [105. 106. 107. 108. 109.]
  [110. 111. 112. 113. 114.]
  [115. 116. 117. 118. 119.]]]
TenTenDotFirst (4, 5)
[[ 5600.  5903.  6212.  6527.  6848.]
 [ 7175.  7508.  7847.  8192.  8543.]
 [ 8900.  9263.  9632. 10007. 10388.]
 [10775. 11168. 11567. 11972. 12383.]]